{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2019 The Google Research Authors.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "     http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Understanding Black-box Model Predictions using RL-LIM\n",
    "\n",
    " * Jinsung Yoon, Sercan O Arik, Tomas Pfister, \"RL-LIM: Reinforcement Learning-based Locally Interpretable Modeling\", arXiv preprint arXiv:1909.12367 (2019) - https://arxiv.org/abs/1909.12367\n",
    " \n",
    "This notebook describes how to explain black-box models using \"Reinforcement Learning based Locally Interpretable Modeling (RL-LIM)\". \n",
    "\n",
    "RL-LIM is a state of the art locally interpretable modeling method. It is often challenging to develop a globally interpretable model that has the performance at the level of 'black-box' models. To go beyond the performance limitations, a promising direction is locally interpretable models, which explain a single prediction, instead of explaining the entire model. Methodologically, while a globally interpretable model fits a single inherently interpretable model (such as a linear model or a shallow decision tree) to the entire training set, locally interpretable models aim to fit an inherently interpretable model locally, i.e. for each instance individually, by distilling knowledge from a high performance black-box model.\n",
    "\n",
    "Such locally interpretable models are very useful for real-world AI deployments to provide succinct and human-like explanations to users. They can be used to identify systematic failure cases (e.g. by seeking common trends in input dependence for failure cases), detect biases (e.g. by quantifying feature importance for a particular variable), and provide actionable feedback to improve a model (e.g. understand failure cases and what training data to collect).\n",
    "\n",
    "You need:\n",
    "\n",
    "**Training / Probe / Testing sets** \n",
    " * If you don't have a probe set, you can construct it by splitting a small portion of training set, while keeping the rest of the training set for training purpose.\n",
    " * The training / probe / testing datasets you have should be saved under './repo/data_files/' directory, with the names: 'train.csv', 'probe.csv', and 'test.csv'.\n",
    " * In this notebook, we create 'train.csv', 'probe.csv', and 'test.csv' files from the Facebook Comment Volume dataset (https://archive.ics.uci.edu/ml/datasets/Facebook+Comment+Volume+Dataset) as an example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  Prerequisite\n",
    "\n",
    " * Download lightgbm package.\n",
    " * Clone https://github.com/google-research/google-research.git to the current directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Installs additional packages\n",
    "import pip\n",
    "import IPython\n",
    "\n",
    "def import_or_install(package):\n",
    "    try:\n",
    "        __import__(package)\n",
    "    except ImportError:\n",
    "        pip.main(['install', package])\n",
    "        app = IPython.Application.instance()\n",
    "        app.kernel.do_shutdown(True)  \n",
    "        \n",
    "import_or_install('lightgbm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from git import Repo\n",
    "\n",
    "# Current working directory\n",
    "repo_dir = os.getcwd() + '/repo'\n",
    "\n",
    "if not os.path.exists(repo_dir):\n",
    "    os.makedirs(repo_dir)\n",
    "\n",
    "# Clones github repository\n",
    "if not os.listdir(repo_dir):\n",
    "    git_url = \"https://github.com/google-research/google-research.git\"\n",
    "    Repo.clone_from(git_url, repo_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Necessary packages and function calls\n",
    "\n",
    " * ridge: Ridge regression model used as an interpretable model.\n",
    " * lightgbm: lightGBM model used as a black-box model.\n",
    " * load_facebook_data: Data loader for facebook comment volumn dataset.\n",
    " * preprocess_data: Data extraction and normalization.\n",
    " * rllim: RL-LIM class for training instance-wise weight estimator.\n",
    " * rllim_metrics: Evaluation metrics for the locally interpretable models in various metrics (overall performance and fidelity)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.linear_model import Ridge\n",
    "import lightgbm\n",
    "\n",
    "# Sets current directory\n",
    "os.chdir(repo_dir)\n",
    "\n",
    "from rllim.data_loading import load_facebook_data, preprocess_data\n",
    "from rllim import rllim\n",
    "from rllim.rllim_metrics import fidelity_metrics, overall_performance_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loading\n",
    "\n",
    " * Load training, probe and testing datasets and save those datasets as train.csv, probe.csv, test.csv in './repo/data_files/' directory.\n",
    " * If you have your own 'train.csv', 'probe.csv', and 'test.csv' files, you can skip this example dataset construction portion and just put them under  './repo/data_files/' directory.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished data loading.\n"
     ]
    }
   ],
   "source": [
    "# The number of training and probe samples (we use 10% of the training set as the probe set). \n",
    "# Explicit testing set exists in facebook comment volume dataset\n",
    "dict_rate = dict()\n",
    "dict_rate['train'] = 0.9\n",
    "dict_rate['probe'] = 0.1\n",
    "\n",
    "# Random seed\n",
    "seed = 0\n",
    "\n",
    "# Loads data\n",
    "load_facebook_data(dict_rate, seed)\n",
    "\n",
    "print('Finished data loading.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data preprocessing\n",
    "\n",
    " * Extract features and labels from train.csv, probe.csv, test.csv in './repo/data_files/' directory.\n",
    " * Normalize the features of training, probe, and testing sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/google/home/jinsungyoon/anaconda3/lib/python3.7/site-packages/sklearn/preprocessing/data.py:334: DataConversionWarning: Data with input dtype int64, float64 were all converted to float64 by MinMaxScaler.\n",
      "  return self.partial_fit(X, y)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished data preprocess.\n"
     ]
    }
   ],
   "source": [
    "# Normalization methods: either 'minmax' or 'standard'\n",
    "normalization = 'minmax' \n",
    "\n",
    "# Extracts features and labels, and then normalize features\n",
    "x_train, y_train, x_probe, y_probe, x_test, y_test, col_names = \\\n",
    "preprocess_data(normalization, 'train.csv', 'probe.csv', 'test.csv')\n",
    "\n",
    "print('Finished data preprocess.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 0: Black-box model training\n",
    "\n",
    "This stage is the preliminary stage for RL-LIM. We train a black-box model (in this notebook, lightGBM) using the training datasets (x_train, y_train) to make a pre-trained black-box model. If you already have a saved pre-trained black-box model, you can skip this stage and retrieve the pre-trained black-box model into bb_model. You also need to specify whether the problem is regression or classification.\n",
    "\n",
    " * Note that the bb_model must have fit, predict (for regression) or predict_proba (for classification) as the methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished black-box model training.\n"
     ]
    }
   ],
   "source": [
    "# Problem specification\n",
    "problem = 'regression' # or 'classification'\n",
    "\n",
    "# Initializes black-box model\n",
    "if problem == 'regression':\n",
    "    bb_model = lightgbm.LGBMRegressor()\n",
    "elif problem == 'classification':\n",
    "    bb_model = lightgbm.LGBMClassifier()\n",
    "\n",
    "# Trains black-box model\n",
    "bb_model = bb_model.fit(x_train, y_train)\n",
    "\n",
    "print('Finished black-box model training.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Auxiliary dataset construction\n",
    "\n",
    "Using the pre-trained black-box model, we create auxiliary training (x_train, y_train_hat) and probe datasets (x_probe, y_probe_hat). These auxiliary datasets are used for instance weight estimator and locally interpretable model training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished auxiliary dataset construction.\n"
     ]
    }
   ],
   "source": [
    "# Constructs auxiliary datasets\n",
    "if problem == 'regression':\n",
    "    y_train_hat = bb_model.predict(x_train)\n",
    "    y_probe_hat = bb_model.predict(x_probe)\n",
    "elif problem == 'classification':\n",
    "    y_train_hat = bb_model.predict_proba(x_train)[:, 1]\n",
    "    y_probe_hat = bb_model.predict_proba(x_probe)[:, 1]\n",
    "    \n",
    "print('Finished auxiliary dataset construction.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Interpretable baseline training\n",
    "\n",
    "To improve the stability of the instance-wise weight estimator training, a baseline model is observed to be beneficial. We use a globally interpretable model (in this notebook, we use Ridge regression) optimized to replicate the predictions of the black-box model.\n",
    "\n",
    "1. **Input**: \n",
    " * Locally interpretable model: ridge regression (we can switch this to shallow tree). The model must have fit, predict (for regression) and predict_proba (for classification) as the subfunctions.\n",
    " \n",
    " \n",
    "2. **Output**:\n",
    " * Trained interpretable baseline model: function that tries to replicate the predictions of the black-box model using globally interpretable model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished interpretable baseline training.\n"
     ]
    }
   ],
   "source": [
    "# Define interpretable baseline model\n",
    "baseline = Ridge(alpha=1)\n",
    "\n",
    "# Trains interpretable baseline model\n",
    "baseline.fit(x_train, y_train_hat)\n",
    "\n",
    "print('Finished interpretable baseline training.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Train instance-wise weight estimator\n",
    "\n",
    "We train an instance-wise weight estimator using the auxiliary training (x_train, y_train_hat) and probe datasets (x_probe, y_probe_hat) using reinforcement learning.\n",
    "\n",
    "1. **Input**: \n",
    " * Network parameters: Set network parameters of instance-wise weight estimator.\n",
    " * Locally interpretable model: Ridge regression (we can switch this to shallow tree). The model must have fit, predict (for regression) or predict_proba (for classification) as the methods.\n",
    " \n",
    " \n",
    "2. **Output**:\n",
    " * Instancewise weight estimator: Function that uses auxiliary training set and a testing sample as inputs to estimate weights for each training sample to construct locally interpretable model for the testing sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /usr/local/google/home/jinsungyoon/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [36:02<00:00,  1.02it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished instance-wise weight estimator training.\n",
      "WARNING:tensorflow:From /usr/local/google/home/jinsungyoon/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n",
      "INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt\n",
      "INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished instance-wise weight estimations, instance-wise predictions, and local explanations.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Instance-wise weight estimator network parameters\n",
    "parameters = dict()\n",
    "parameters['hidden_dim'] = 100\n",
    "parameters['iterations'] = 2000\n",
    "parameters['num_layers'] = 5\n",
    "parameters['batch_size'] = 5000\n",
    "parameters['batch_size_inner'] = 10\n",
    "parameters['lambda'] = 1.0\n",
    "\n",
    "# Defines locally interpretable model\n",
    "interp_model = Ridge(alpha = 1)\n",
    "\n",
    "# Checkpoint file name\n",
    "checkpoint_file_name = './tmp/model.ckpt'\n",
    "\n",
    "# Initializes RL-LIM\n",
    "rllim_class = rllim.Rllim(x_train, y_train_hat, x_probe, y_probe_hat, parameters, \n",
    "                          interp_model, baseline, checkpoint_file_name)\n",
    "\n",
    "# Trains RL-LIM\n",
    "rllim_class.rllim_train()\n",
    "\n",
    "print('Finished instance-wise weight estimator training.')\n",
    "\n",
    "## Output functions\n",
    "# Instance-wise weight estimation for x_test[0, :]\n",
    "dve_out = rllim_class.instancewise_weight_estimator(x_train, y_train_hat, x_test[0, :])\n",
    "\n",
    "# Interpretable predictions (test_y_fit) and instance-wise explanations (test_coef) for x_test[:0, :]\n",
    "test_y_fit, test_coef = rllim_class.rllim_interpreter(x_train, y_train_hat, x_test[0, :], interp_model)\n",
    "\n",
    "print('Finished instance-wise weight estimations, instance-wise predictions, and local explanations.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Interpretable inference\n",
    "\n",
    "Unlike Step 3 (training instance-wise weight estimator), we use a fixed instance-wise weight estimator (without the sampler and interpretable baseline) and merely fit the locally interpretable model at inference. Given the test instance, we obtain the selection probabilities from the instance-wise weight estimator, and using these as the weights, we fit the locally interpretable model via weighted optimization. \n",
    "\n",
    "1. **Input**: \n",
    " * Locally interpretable model: Ridge regression (we can switch this to shallow tree). The model must have fit, predict (for regression) and predict_proba (for classification) as the subfunctions.\n",
    " \n",
    " \n",
    "2. **Output**:\n",
    " * Instance-wise explanations (test_coef): Estimated local dynamics for testing samples using trained locally interpretable model.\n",
    " * Interpretable predictions (test_y_fit): Local predictions for testing samples using trained locally interpretable model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 900/900 [15:26<00:00,  1.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished instance-wise predictions and local explanations.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Train locally interpretable models and output instance-wise explanations (test_coef) and\n",
    "# interpretable predictions (test_y_fit) \n",
    "test_y_fit, test_coef = rllim_class.rllim_interpreter(x_train, y_train_hat, x_test, interp_model)\n",
    "\n",
    "print('Finished instance-wise predictions and local explanations.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "We use two quantitative metrics (overall performance and fidelity) and one qualitative metric (instance-wise explanations) to evaluate the locally interpretable models.\n",
    "\n",
    " * Overall_performance: Difference between ground truth labels (y_test) and interpretable predictions (test_y_fit). \n",
    " * Fidelity: Difference between black-box model predictions (y_test_hat) and interpretable predictions (test_y_fit).\n",
    " * Instance-wise explanations: Qualitatively show the examples of instance-wise explanations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Overall performance\n",
    "\n",
    " * We use Mean Absolute Error (MAE) as the metric for the overall performance. However, users can replace MAE to RMSE or others."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall performance of RL-LIM in terms of MAE: 24.7097\n"
     ]
    }
   ],
   "source": [
    "# Overall performance\n",
    "mae = overall_performance_metrics (y_test, test_y_fit, metric='mae')\n",
    "print('Overall performance of RL-LIM in terms of MAE: ' + str(np.round(mae, 4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Fidelity\n",
    "\n",
    " * We use R2 score and Mean Absolute Error (MAE) as the metrics for the fidelity. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fidelity of RL-LIM in terms of MAE: 20.5316\n",
      "Fidelity of RL-LIM in terms of R2 Score: 0.433\n"
     ]
    }
   ],
   "source": [
    "# Black-box model predictions\n",
    "y_test_hat = bb_model.predict(x_test)\n",
    "\n",
    "# Fidelity in terms of MAE\n",
    "mae = fidelity_metrics (y_test_hat, test_y_fit, metric='mae')\n",
    "print('Fidelity of RL-LIM in terms of MAE: ' + str(np.round(mae, 4)))\n",
    "\n",
    "# Fidelity in terms of R2 Score\n",
    "r2 = fidelity_metrics (y_test_hat, test_y_fit, metric='r2')\n",
    "print('Fidelity of RL-LIM in terms of R2 Score: ' + str(np.round(r2, 4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Instance-wise explanations\n",
    "\n",
    " * We qualitatively demonstrate the local explanations of 5 testing samples using the fitted coefficients of locally interpretable model (Ridge regression).\n",
    " * To run this cell, the interpretable model must have intercept_ and coef_ as the subfunctions. Here, intercept and coef represent the fitted locally interpretable model's intercept and coefficients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>intercept</th>\n",
       "      <th>Page Popularity/likes</th>\n",
       "      <th>Page Check</th>\n",
       "      <th>Page talk</th>\n",
       "      <th>Page Category</th>\n",
       "      <th>min # of comments</th>\n",
       "      <th>min # of comments in last 24 hours</th>\n",
       "      <th>min # of comments in last 48 hours</th>\n",
       "      <th>min # of comments in the first 24 hours</th>\n",
       "      <th>min # of comments in last 48 to last 24 hours</th>\n",
       "      <th>...</th>\n",
       "      <th>post was published on Thursday</th>\n",
       "      <th>post was published on Friday</th>\n",
       "      <th>post was published on Saturday</th>\n",
       "      <th>basetime (Sunday)</th>\n",
       "      <th>basetime (Monday)</th>\n",
       "      <th>basetime (Tuesday)</th>\n",
       "      <th>basetime (Wednesday)</th>\n",
       "      <th>basetime (Thursday)</th>\n",
       "      <th>basetime (Friday)</th>\n",
       "      <th>basetime (Saturday)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-103.330512</td>\n",
       "      <td>4.422585</td>\n",
       "      <td>-15.001978</td>\n",
       "      <td>4.496328</td>\n",
       "      <td>1.229217</td>\n",
       "      <td>4.497695</td>\n",
       "      <td>18.962609</td>\n",
       "      <td>32.218586</td>\n",
       "      <td>29.317612</td>\n",
       "      <td>25.755560</td>\n",
       "      <td>...</td>\n",
       "      <td>0.295788</td>\n",
       "      <td>-0.005714</td>\n",
       "      <td>-0.559625</td>\n",
       "      <td>0.078723</td>\n",
       "      <td>0.610730</td>\n",
       "      <td>0.452924</td>\n",
       "      <td>0.481491</td>\n",
       "      <td>-0.878218</td>\n",
       "      <td>-0.243942</td>\n",
       "      <td>-0.501708</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-104.206197</td>\n",
       "      <td>4.253650</td>\n",
       "      <td>-14.915397</td>\n",
       "      <td>3.943046</td>\n",
       "      <td>1.137022</td>\n",
       "      <td>4.459325</td>\n",
       "      <td>18.666929</td>\n",
       "      <td>32.591710</td>\n",
       "      <td>29.570496</td>\n",
       "      <td>26.151958</td>\n",
       "      <td>...</td>\n",
       "      <td>0.130149</td>\n",
       "      <td>0.015970</td>\n",
       "      <td>-0.533755</td>\n",
       "      <td>0.159478</td>\n",
       "      <td>0.585260</td>\n",
       "      <td>0.235655</td>\n",
       "      <td>0.421461</td>\n",
       "      <td>-0.735738</td>\n",
       "      <td>-0.339109</td>\n",
       "      <td>-0.327006</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-103.286616</td>\n",
       "      <td>4.186166</td>\n",
       "      <td>-14.795392</td>\n",
       "      <td>3.534093</td>\n",
       "      <td>1.126795</td>\n",
       "      <td>4.319873</td>\n",
       "      <td>18.569088</td>\n",
       "      <td>32.592334</td>\n",
       "      <td>29.231428</td>\n",
       "      <td>26.181185</td>\n",
       "      <td>...</td>\n",
       "      <td>0.176272</td>\n",
       "      <td>-0.096792</td>\n",
       "      <td>-0.562790</td>\n",
       "      <td>0.202239</td>\n",
       "      <td>0.594071</td>\n",
       "      <td>0.248880</td>\n",
       "      <td>0.311070</td>\n",
       "      <td>-0.700757</td>\n",
       "      <td>-0.387459</td>\n",
       "      <td>-0.268045</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-101.693317</td>\n",
       "      <td>4.348262</td>\n",
       "      <td>-14.905574</td>\n",
       "      <td>5.691516</td>\n",
       "      <td>1.223091</td>\n",
       "      <td>4.993201</td>\n",
       "      <td>18.508501</td>\n",
       "      <td>31.850354</td>\n",
       "      <td>29.686880</td>\n",
       "      <td>23.989133</td>\n",
       "      <td>...</td>\n",
       "      <td>0.204816</td>\n",
       "      <td>-0.003839</td>\n",
       "      <td>-0.530760</td>\n",
       "      <td>0.177499</td>\n",
       "      <td>0.659650</td>\n",
       "      <td>0.291278</td>\n",
       "      <td>0.422315</td>\n",
       "      <td>-0.789438</td>\n",
       "      <td>-0.333941</td>\n",
       "      <td>-0.427364</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-101.229288</td>\n",
       "      <td>4.318020</td>\n",
       "      <td>-14.496846</td>\n",
       "      <td>5.550790</td>\n",
       "      <td>1.209830</td>\n",
       "      <td>4.939441</td>\n",
       "      <td>18.472585</td>\n",
       "      <td>31.814553</td>\n",
       "      <td>29.207907</td>\n",
       "      <td>24.082570</td>\n",
       "      <td>...</td>\n",
       "      <td>0.230303</td>\n",
       "      <td>-0.065082</td>\n",
       "      <td>-0.528969</td>\n",
       "      <td>0.089551</td>\n",
       "      <td>0.525554</td>\n",
       "      <td>0.416373</td>\n",
       "      <td>0.367239</td>\n",
       "      <td>-0.741127</td>\n",
       "      <td>-0.354910</td>\n",
       "      <td>-0.302679</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 54 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    intercept  Page Popularity/likes  Page Check  Page talk  Page Category  \\\n",
       "0 -103.330512               4.422585  -15.001978   4.496328       1.229217   \n",
       "1 -104.206197               4.253650  -14.915397   3.943046       1.137022   \n",
       "2 -103.286616               4.186166  -14.795392   3.534093       1.126795   \n",
       "3 -101.693317               4.348262  -14.905574   5.691516       1.223091   \n",
       "4 -101.229288               4.318020  -14.496846   5.550790       1.209830   \n",
       "\n",
       "   min # of comments  min # of comments in last 24 hours  \\\n",
       "0           4.497695                           18.962609   \n",
       "1           4.459325                           18.666929   \n",
       "2           4.319873                           18.569088   \n",
       "3           4.993201                           18.508501   \n",
       "4           4.939441                           18.472585   \n",
       "\n",
       "   min # of comments in last 48 hours  \\\n",
       "0                           32.218586   \n",
       "1                           32.591710   \n",
       "2                           32.592334   \n",
       "3                           31.850354   \n",
       "4                           31.814553   \n",
       "\n",
       "   min # of comments in the first 24 hours  \\\n",
       "0                                29.317612   \n",
       "1                                29.570496   \n",
       "2                                29.231428   \n",
       "3                                29.686880   \n",
       "4                                29.207907   \n",
       "\n",
       "   min # of comments in last 48 to last 24 hours  ...  \\\n",
       "0                                      25.755560  ...   \n",
       "1                                      26.151958  ...   \n",
       "2                                      26.181185  ...   \n",
       "3                                      23.989133  ...   \n",
       "4                                      24.082570  ...   \n",
       "\n",
       "   post was published on Thursday  post was published on Friday  \\\n",
       "0                        0.295788                     -0.005714   \n",
       "1                        0.130149                      0.015970   \n",
       "2                        0.176272                     -0.096792   \n",
       "3                        0.204816                     -0.003839   \n",
       "4                        0.230303                     -0.065082   \n",
       "\n",
       "   post was published on Saturday  basetime (Sunday)  basetime (Monday)  \\\n",
       "0                       -0.559625           0.078723           0.610730   \n",
       "1                       -0.533755           0.159478           0.585260   \n",
       "2                       -0.562790           0.202239           0.594071   \n",
       "3                       -0.530760           0.177499           0.659650   \n",
       "4                       -0.528969           0.089551           0.525554   \n",
       "\n",
       "   basetime (Tuesday)  basetime (Wednesday)  basetime (Thursday)  \\\n",
       "0            0.452924              0.481491            -0.878218   \n",
       "1            0.235655              0.421461            -0.735738   \n",
       "2            0.248880              0.311070            -0.700757   \n",
       "3            0.291278              0.422315            -0.789438   \n",
       "4            0.416373              0.367239            -0.741127   \n",
       "\n",
       "   basetime (Friday)  basetime (Saturday)  \n",
       "0          -0.243942            -0.501708  \n",
       "1          -0.339109            -0.327006  \n",
       "2          -0.387459            -0.268045  \n",
       "3          -0.333941            -0.427364  \n",
       "4          -0.354910            -0.302679  \n",
       "\n",
       "[5 rows x 54 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Local explanations of n samples\n",
    "n = 5\n",
    "local_explanations = test_coef[:n, :]\n",
    "\n",
    "# Make pandas dataframe\n",
    "final_col_names = np.concatenate((np.asarray(['intercept']), col_names), axis = 0) \n",
    "pd.DataFrame(data=local_explanations, index=range(n), columns=final_col_names)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
